package com.tca.common.data.mybatis.handler;

import com.tca.common.core.bean.DataScope;
import com.tca.common.core.datapermission.DataPermissionContextHolder;
import com.tca.common.core.enums.DataScopeTypeEnum;
import com.tca.common.core.utils.ValidateUtils;
import com.tca.common.data.mybatis.config.DataPermissionProperties;
import com.tca.common.data.mybatis.utils.CamelCaseUtils;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.HexValue;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.ItemsList;
import net.sf.jsqlparser.schema.Column;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaClass;
import org.apache.ibatis.reflection.ReflectorFactory;
import org.springframework.util.PatternMatchUtils;

import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

import static com.tca.common.core.enums.DataScopeTypeEnum.ALL;


/**
 * @author zhoua
 * @date 2022/2/14 16:21
 */
@Slf4j
public class DataPermissionHandler implements com.baomidou.mybatisplus.extension.plugins.handler.DataPermissionHandler {
    
    private DataPermissionContextHolder dataPermissionContextHolder;
    
    private DataPermissionProperties dataPermissionProperties;
    
    /**
     * ignoreTableNames使用大写, 去下划线
     */
    private List<String> ignoreTableNames;
    
    /**
     * 存储 mappedStatementId - tableName(tableName使用大写, 去下划线)
     */
    private Map<String, String> mappedStatementIdTableName = new ConcurrentHashMap<>();
    
    /**
     * 存储 className - <field, boolean>
     * 根据 class 是否包含 field属性
     */
    private Map<String, Map<String, Boolean>> classFieldExistsMap = new ConcurrentHashMap<>();

    /**
     * 反射工厂
     */
    private ReflectorFactory reflectorFactory = new DefaultReflectorFactory();
    
    public DataPermissionHandler(DataPermissionContextHolder dataPermissionContextHolder, DataPermissionProperties dataPermissionProperties) {
        this.dataPermissionContextHolder = dataPermissionContextHolder;
        this.dataPermissionProperties = dataPermissionProperties;
        
        initIgnoreTables(dataPermissionProperties);
    }
    
    private void initIgnoreTables(DataPermissionProperties dataPermissionProperties) {
        List<String> ignoreTables = dataPermissionProperties.getIgnoreTables();
        if (ValidateUtils.isEmpty(ignoreTables)) {
            this.ignoreTableNames = new ArrayList<>();
        } else {
            this.ignoreTableNames = ignoreTables.stream().map(s -> s.replaceAll("_", "")
                    .toUpperCase(Locale.ENGLISH)).collect(Collectors.toList());
        }
    }
    
    /**
     * @param where             原SQL Where 条件表达式
     * @param mappedStatementId Mapper接口方法ID
     * @return
     */
    @Override
    public Expression getSqlSegment(Expression where, String mappedStatementId) {
        
        // 忽略表
        if (tableIgnore(mappedStatementId)) {
            return where;
        }
        
        // 忽略方法
        if (methodIgnore(mappedStatementId)) {
            return where;
        }
    
        // 忽略无租户字段
        if (ignoreWithoutField(mappedStatementId, CamelCaseUtils.toCamelCaseField(dataPermissionProperties.getTenantColumn()))) {
            return where;
        }
        
        // 获取用户权限
        DataScope dataScope = dataPermissionContextHolder.getDataScope();
        
        if (Objects.isNull(dataScope)) {
            return where;
        }
        Integer dataScopeTypeCode = dataScope.getDataScopeTypeCode();
        DataScopeTypeEnum dataScopeTypeEnum = DataScopeTypeEnum.getByCode(dataScopeTypeCode);
    
        try {
            log.debug("开始进行权限过滤, dataFilterMetaData:{}, where: {}, mappedStatementId: {}",
                    dataScope, where, mappedStatementId);
            // 查看全部
            if (dataScopeTypeEnum == ALL) {
                return where;
            }
            Expression expression = new HexValue(" 1 = 1 ");
            if (where == null) {
                where = expression;
            }
            switch (dataScopeTypeEnum) {
                // 查看租户级别
                case TENANT:
                    EqualsTo equalsTo = new EqualsTo();
                    equalsTo.setLeftExpression(new Column(dataPermissionProperties.getTenantColumn()));
                    equalsTo.setRightExpression(new LongValue(dataScope.getTenant()));
                    return new AndExpression(where, equalsTo);
                    
                // 查看部门级别
                case DEPARTMENT:
                    // 商户查询
                    EqualsTo tenantEqualsTo = new EqualsTo();
                    tenantEqualsTo.setLeftExpression(new Column(dataPermissionProperties.getTenantColumn()));
                    tenantEqualsTo.setRightExpression(new LongValue(dataScope.getTenant()));
                    AndExpression enterpriseExpression = new AndExpression(where, tenantEqualsTo);
    
                    // 没有部门字段, 不用继续拼接部门
                    if (ignoreWithoutField(mappedStatementId, CamelCaseUtils.toCamelCaseField(dataPermissionProperties.getDepartmentColumn()))) {
                        return enterpriseExpression;
                    }
                    
                    // 使用and进行拼接
                    AndExpression andExpression = new AndExpression();
                    andExpression.setLeftExpression(enterpriseExpression);
                    
                    // 部门查询
                    List<String> departmentList = dataScope.getDepartmentList();
                    ItemsList itemsList = new ExpressionList(departmentList.stream().map(LongValue::new).collect(Collectors.toList()));
                    InExpression inExpression = new InExpression(new Column(dataPermissionProperties.getDepartmentColumn()), itemsList);
    
                    andExpression.setRightExpression(inExpression);
                    return andExpression;
                default:
                    break;
            }
        } catch (Exception e) {
            log.error("TcaDataPermissionHandler.err", e);
        }
        return where;
    }
    
    /**
     * 当前实体无 tenant 字段, 忽略
     * @param mappedStatementId
     * @return
     *
     */
    private boolean ignoreWithoutField(String mappedStatementId, String fieldName) {
        // 获取全限定类名
        String className = mappedStatementId.substring(0, mappedStatementId.lastIndexOf("."));
    
        Map<String, Boolean> fieldExistsMap;
        Boolean exists;
        
        fieldExistsMap = classFieldExistsMap.get(className);
        if (ValidateUtils.isEmpty(fieldExistsMap)) {
            fieldExistsMap = new ConcurrentHashMap<>(4);
            exists = classFieldExists(className, fieldName);
            fieldExistsMap.put(fieldName, exists);
            classFieldExistsMap.put(className, fieldExistsMap);
        } else {
            exists = fieldExistsMap.get(fieldName);
            if (ValidateUtils.isEmpty(exists)) {
                exists = classFieldExists(className, fieldName);
                fieldExistsMap.put(fieldName, exists);
            }
        }
        
        // 存在时返回false
        return !exists;
    }
    
    /**
     * 判断class中是否有field属性
     * @param className
     * @param fieldName
     * @return
     */
    private boolean classFieldExists(String className, String fieldName) {
        // 获取类
        Class<?> clazz;
        try {
            clazz = Class.forName(className);
        } catch (ClassNotFoundException e) {
            log.error("classNotFound:", e);
            return false;
        }
    
        // 获取父接口中的泛型类 -- 实体类
        Class tClazz = (Class)(((ParameterizedType)(clazz.getGenericInterfaces()[0])).getActualTypeArguments())[0];
    
        // 判断class有没有租户字段
		MetaClass metaClass = MetaClass.forClass(tClazz, reflectorFactory);
        return metaClass.hasSetter(fieldName);
    }
    
    /**
     * 是否忽略当前方法
     * @param mappedStatementId
     * @return
     */
    private boolean methodIgnore(String mappedStatementId) {
        List<String> ignoreMethods = dataPermissionProperties.getIgnoreMethods();
        if (ValidateUtils.isNotEmpty(ignoreMethods) && ignoreMethods.contains(mappedStatementId)) {
            return true;
        }
        List<String> ignoreMethodPatterns = dataPermissionProperties.getIgnoreMethodPatterns();
        if (ValidateUtils.isNotEmpty(ignoreMethodPatterns)) {
            String[] ignoreMethodPatternArray = ignoreMethodPatterns.toArray(new String[ignoreMethodPatterns.size()]);
            if (PatternMatchUtils.simpleMatch(ignoreMethodPatternArray, mappedStatementId)) {
                return true;
            }
        }
        return false;
    }
    
    /**
     * 是否忽略当前表
     * @param mappedStatementId
     * @return
     */
    private boolean tableIgnore(String mappedStatementId) {
        String tableName = mappedStatementIdTableName.get(mappedStatementId);
        if (ValidateUtils.isEmpty(tableName)) {
            String[] split = mappedStatementId.split("\\.", -1);
            String mapperName = split[split.length - 2];
            tableName = mapperName.substring(0, mapperName.length() - 6).toUpperCase(Locale.ENGLISH);
            mappedStatementIdTableName.put(mappedStatementId, tableName);
        }
        
        return ignoreTableNames.contains(tableName);
    }
    
}
